"""
Flow API Server - WebSocket 브릿지
브라우저의 Tampermonkey 스크립트와 WebSocket으로 통신

사용법:
1. pip install fastapi uvicorn websockets
2. python flow_api_server.py
3. Tampermonkey 스크립트 설치 후 Flow 페이지 열기
4. API 호출: POST http://localhost:8002/generate
"""

import asyncio
import json
import uuid
from typing import Optional
from datetime import datetime

from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import uvicorn


class GenerateRequest(BaseModel):
    """Request body for ``POST /generate``."""

    # Text prompt for image generation (required; no default).
    prompt: str = Field(..., description="이미지 생성 프롬프트")
    # Number of images to request; pydantic enforces 1 <= count <= 8.
    count: int = Field(default=2, ge=1, le=8, description="생성할 이미지 수")
    # Model identifier, passed through without validation. The description
    # lists IMAGEN4, GEM_PIX and GEM_PIX_2 as the expected values.
    model: str = Field(default="GEM_PIX_2", description="모델: IMAGEN4, GEM_PIX, GEM_PIX_2")
    # Aspect-ratio name, passed through without validation. The description
    # lists square, landscape and portrait as the expected values.
    aspect_ratio: str = Field(default="landscape", description="비율: square, landscape, portrait")


class GenerateResponse(BaseModel):
    """Response body for ``POST /generate``, built from the browser's reply."""

    # Whether the browser side reported success.
    success: bool
    # Count reported by the browser, if any (presumably the number of images
    # produced; TODO confirm against the userscript).
    count: Optional[int] = None
    # Image entries as returned by the browser script. Their element format is
    # not defined in this file; confirm against the userscript.
    images: list = Field(default_factory=list)
    # Error message when success is False.
    error: Optional[str] = None


app = FastAPI(
    title="Flow API Server",
    version="1.0.0",
    description="Google Flow 이미지 생성 API (WebSocket 브릿지)",
)

# Fully permissive CORS: every origin, method and header may call the HTTP API.
# NOTE(review): allow_credentials=True combined with a wildcard origin is
# unusual. No cookies or auth are visible in this file, so credentials are
# presumably unnecessary here; confirm before tightening.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)


class ConnectionManager:
    """Track the single browser WebSocket and correlate requests with replies.

    Every outgoing request is tagged with a fresh ``requestId``. The browser
    echoes that id back together with a ``result`` payload, and
    :meth:`handle_response` uses it to resolve the matching waiting future.
    """

    def __init__(self):
        # The currently active browser socket, or None when none is attached.
        self.browser_connection: Optional["WebSocket"] = None
        # request_id -> asyncio.Future resolved with the browser's result.
        self.pending_requests: dict[str, asyncio.Future] = {}

    async def connect_browser(self, websocket: "WebSocket"):
        """Accept *websocket* and make it the active browser connection.

        Any previously active connection is replaced (not closed) here.
        """
        await websocket.accept()
        self.browser_connection = websocket
        print("[WS] Browser connected")

    def disconnect_browser(self):
        """Forget the active browser and fail every request still waiting.

        Waiting callers receive ``{"success": False, "error": "Browser disconnected"}``
        immediately instead of blocking until their timeout expires.
        """
        self.browser_connection = None
        for future in self.pending_requests.values():
            if not future.done():
                future.set_result({"success": False, "error": "Browser disconnected"})
        self.pending_requests.clear()
        print("[WS] Browser disconnected")

    async def send_to_browser(self, message: dict, timeout: float = 60.0) -> dict:
        """Send *message* to the browser and wait for the correlated reply.

        Args:
            message: JSON-serialisable payload. It is not modified; a copy
                tagged with a fresh ``requestId`` is sent instead.
            timeout: Seconds to wait for the browser's reply (default 60).

        Returns:
            The ``result`` value the browser sent back for this request, or the
            failure dict set by :meth:`disconnect_browser` if it disconnected.

        Raises:
            HTTPException: 503 if no browser is connected or the message could
                not be sent; 504 if no reply arrives within *timeout*.
        """
        websocket = self.browser_connection
        if websocket is None:
            raise HTTPException(status_code=503, detail="브라우저가 연결되지 않았습니다. Flow 페이지를 열어주세요.")

        request_id = str(uuid.uuid4())
        # Copy instead of mutating the caller's dict.
        outgoing = {**message, "requestId": request_id}

        # Register the future *before* sending so a fast reply cannot be missed.
        future = asyncio.get_running_loop().create_future()
        self.pending_requests[request_id] = future

        try:
            try:
                await websocket.send_json(outgoing)
            except Exception as exc:
                # A dead socket would otherwise surface as an opaque 500.
                raise HTTPException(status_code=503, detail="브라우저로 요청을 전송하지 못했습니다.") from exc
            return await asyncio.wait_for(future, timeout=timeout)
        except asyncio.TimeoutError:
            raise HTTPException(status_code=504, detail=f"요청 타임아웃 ({timeout:g}초)") from None
        finally:
            # Drop the bookkeeping entry whatever the outcome.
            self.pending_requests.pop(request_id, None)

    def handle_response(self, request_id: str, result: dict):
        """Resolve the waiting future for *request_id* with *result*.

        Replies for unknown, timed-out or already-resolved requests are ignored.
        Only string ids can match (ids are generated as uuid4 strings). Any
        other JSON value, including unhashable lists or objects, is ignored
        rather than raising.
        """
        if not isinstance(request_id, str):
            return
        future = self.pending_requests.get(request_id)
        if future is not None and not future.done():
            future.set_result(result)


# Module-level singleton shared by the HTTP routes and the /ws endpoint below.
manager = ConnectionManager()


@app.get("/")
async def root():
    """Report basic service info and whether a browser is attached."""
    connected = manager.browser_connection is not None
    return dict(
        service="Flow API Server",
        status="running",
        browser_connected=connected,
        docs="/docs",
    )


@app.get("/health")
async def health():
    """Health probe: status, browser attachment, and server-local timestamp."""
    payload = {"status": "ok"}
    payload["browser_connected"] = manager.browser_connection is not None
    payload["timestamp"] = datetime.now().isoformat()
    return payload


@app.post("/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest):
    """Image generation API: generate images for one prompt via the browser.

    Forwards the request to the browser script as a ``generate`` action and
    wraps its reply in ``GenerateResponse``.

    Raises:
        HTTPException: 503 if no browser is connected. ``manager.send_to_browser``
            may also raise (e.g. 504 when the browser does not reply in time).
    """
    # Compare with None explicitly, as root() and health() do, instead of
    # relying on the truthiness of the WebSocket object.
    if manager.browser_connection is None:
        raise HTTPException(
            status_code=503,
            detail="브라우저가 연결되지 않았습니다. Flow 페이지(https://labs.google/fx/tools/flow)를 열고 로그인하세요."
        )

    print(f"[API] Generate request: {request.prompt[:50]}...")

    result = await manager.send_to_browser({
        "action": "generate",
        "params": {
            "prompt": request.prompt,
            "count": request.count,
            "model": request.model,
            "aspectRatio": request.aspect_ratio
        }
    })

    # The browser reply is arbitrary JSON. A non-object (e.g. null) would make
    # the ** unpacking below raise TypeError and return an opaque 500.
    if not isinstance(result, dict):
        return GenerateResponse(success=False, error="Invalid response from browser")

    return GenerateResponse(**result)


@app.post("/generate/batch")
async def generate_batch(prompts: list[str], count: int = 2, model: str = "GEM_PIX_2", aspect_ratio: str = "landscape"):
    """Batch image generation API: run several prompts sequentially.

    Prompts are sent one at a time with a 2-second pause between them. A
    prompt whose request fails (e.g. times out) is recorded in its slot as
    ``{"success": False, "error": <detail>}`` instead of aborting the batch,
    so results already produced for earlier prompts are not thrown away.

    Returns:
        ``{"results": [...]}`` with one entry per prompt, in input order.

    Raises:
        HTTPException: 422 if *count* is outside 1..8 (the same bounds that
            GenerateRequest enforces); 503 if no browser is connected when the
            batch starts.
    """
    # Mirror GenerateRequest.count (ge=1, le=8), which this endpoint bypasses.
    if not 1 <= count <= 8:
        raise HTTPException(status_code=422, detail="count는 1에서 8 사이여야 합니다.")

    if manager.browser_connection is None:
        raise HTTPException(status_code=503, detail="브라우저가 연결되지 않았습니다.")

    total = len(prompts)
    results = []
    for index, prompt in enumerate(prompts):
        print(f"[API] Batch {index + 1}/{total}: {prompt[:30]}...")

        try:
            result = await manager.send_to_browser({
                "action": "generate",
                "params": {
                    "prompt": prompt,
                    "count": count,
                    "model": model,
                    "aspectRatio": aspect_ratio
                }
            })
        except HTTPException as exc:
            # Report this prompt's failure in place and keep going, using the
            # same shape ConnectionManager uses for a disconnected browser.
            result = {"success": False, "error": exc.detail}
        results.append(result)

        # Pause between requests (not after the last one).
        if index < total - 1:
            await asyncio.sleep(2)

    return {"results": results}


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket endpoint for the browser connection.

    A message carrying both ``requestId`` and ``result`` resolves the matching
    pending request. ``{"type": "ping"}`` is answered with ``{"type": "pong"}``.
    Any other message is ignored.
    """
    await manager.connect_browser(websocket)

    try:
        while True:
            data = await websocket.receive_json()

            # Valid JSON that is not an object (list, number, ...) has none of
            # the fields we look for. Skip it instead of letting data.get()
            # raise and tear the connection down.
            if not isinstance(data, dict):
                continue

            # Reply from the browser for an earlier request.
            if "requestId" in data and "result" in data:
                manager.handle_response(data["requestId"], data["result"])

            # Keep-alive ping/pong.
            elif data.get("type") == "ping":
                await websocket.send_json({"type": "pong"})

    except WebSocketDisconnect:
        pass
    except Exception as e:
        print(f"[WS] Error: {e}")
    finally:
        # If another browser connected after this one, this socket is stale.
        # Calling disconnect_browser() would drop the newer, still-live
        # connection and fail the requests waiting on it.
        if manager.browser_connection is websocket:
            manager.disconnect_browser()
        else:
            print("[WS] Stale browser connection closed")


if __name__ == "__main__":
    divider = "=" * 60
    startup_banner = (
        "\n" + divider,
        "Flow API Server",
        divider,
        "\n1. 이 서버를 실행한 상태로 유지",
        "2. Chrome에서 Flow 페이지 열기: https://labs.google/fx/tools/flow",
        "3. Tampermonkey 스크립트가 자동으로 WebSocket 연결",
        "4. API 사용: POST http://localhost:8002/generate",
        "\nAPI 문서: http://localhost:8002/docs",
        divider + "\n",
    )
    for banner_line in startup_banner:
        print(banner_line)

    # Serve on every interface, port 8002 (matching the URLs printed above).
    # NOTE(review): no authentication is visible in this file, so 0.0.0.0 also
    # exposes the API to the local network; consider 127.0.0.1 if that is
    # unintended.
    uvicorn.run(app, host="0.0.0.0", port=8002)
